{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# 获取项目根目录的绝对路径\n",
    "current_dir = os.path.dirname(os.path.abspath(''))\n",
    "\n",
    "# 确保项目根目录在系统路径中\n",
    "if current_dir not in sys.path:\n",
    "    sys.path.append(current_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "from student.output import DeepResNet\n",
    "from data.distill_dataset import DistillDataManager\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 加载数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[4.2537, 7.2491, 6.1717, 9.2729, 4.9853, 9.8473, 3.1820, 9.3614, 8.1104,\n",
      "         9.0288],\n",
      "        [2.3149, 8.7649, 5.0795, 8.0234, 4.8227, 8.7122, 0.7336, 2.7354, 7.6845,\n",
      "         6.5359],\n",
      "        [6.5778, 9.8060, 0.8749, 7.8437, 6.0056, 7.5152, 8.9349, 6.0827, 4.3046,\n",
      "         7.7731],\n",
      "        [0.3096, 1.4307, 0.4254, 2.0687, 3.2475, 0.8592, 9.9883, 3.1064, 7.1883,\n",
      "         2.8260],\n",
      "        [6.0670, 4.3900, 8.1416, 5.1876, 1.6904, 2.8222, 9.9509, 8.0686, 6.5593,\n",
      "         5.9664],\n",
      "        [9.7981, 5.3376, 9.5862, 0.3684, 5.5181, 7.1683, 3.2136, 5.2784, 1.8340,\n",
      "         9.4031],\n",
      "        [2.1909, 0.4959, 9.4064, 1.8535, 2.7756, 7.6035, 9.4730, 6.0837, 7.2479,\n",
      "         1.7767],\n",
      "        [2.4067, 8.6760, 5.0395, 6.0531, 6.9528, 0.0580, 0.6878, 2.1934, 4.4764,\n",
      "         4.8344],\n",
      "        [4.9628, 6.8850, 0.1238, 5.7276, 2.3660, 0.6239, 9.2926, 6.6920, 3.6652,\n",
      "         1.5679],\n",
      "        [2.8365, 5.7207, 1.2948, 2.9172, 4.4617, 8.5791, 4.1536, 8.3898, 6.0025,\n",
      "         7.9040]])\n"
     ]
    }
   ],
   "source": [
    "radius_save_path = \"../data/radius_matrix.pth\"\n",
    "\n",
    "if os.path.exists(radius_save_path):\n",
    "    radius_matrix = torch.load(radius_save_path, weights_only=False)\n",
    "else:\n",
    "    radius_matrix = torch.rand(10, 10) * 10\n",
    "    torch.save(radius_matrix, radius_save_path)\n",
    "\n",
    "print(radius_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_dataset loaded.\n",
      "test_dataset loaded.\n",
      "<data.distill_dataset.data_manager.FDTDDataset object at 0x76c0e13ed090>\n",
      "<data.distill_dataset.data_manager.FDTDDataset object at 0x76c092fdb850>\n"
     ]
    }
   ],
   "source": [
    "data_manager = DistillDataManager(radius_matrix)\n",
    "\n",
    "train_dataset = data_manager.get_train_dataset(mode='output')\n",
    "print(\"train_dataset loaded.\")\n",
    "test_dataset = data_manager.get_test_dataset(mode='output')\n",
    "print(\"test_dataset loaded.\")\n",
    "\n",
    "print(train_dataset)\n",
    "print(test_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "model = DeepResNet(\n",
    "    input_dim=10,\n",
    "    output_dim=10,\n",
    "    hidden_dim=2048,\n",
    "    num_blocks=8\n",
    ").to(device)\n",
    "\n",
    "model = data_manager.load_model(model, mode='output')\n",
    "\n",
    "num_epochs = 100000\n",
    "learning_rate = 1e-3\n",
    "batch_size = 64\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "criterion = torch.nn.MSELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 68, Loss: 2.0323, LastEpochLoss: 1.5648:   0%|          | 1281/1900000 [00:50<21:07:18, 24.97it/s] "
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[6], line 23\u001b[0m\n\u001b[1;32m     21\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[1;32m     22\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mstep()\n\u001b[0;32m---> 23\u001b[0m epoch_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mitem\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     24\u001b[0m progress_bar\u001b[38;5;241m.\u001b[39mupdate(\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m     26\u001b[0m progress_bar\u001b[38;5;241m.\u001b[39mset_description(\n\u001b[1;32m     27\u001b[0m     \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Loss: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mloss\u001b[38;5;241m.\u001b[39mitem()\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \n\u001b[1;32m     28\u001b[0m     \u001b[38;5;241m+\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m, LastEpochLoss: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mlast_epoch_loss\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m     29\u001b[0m )\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "train_loss = []\n",
    "best_loss = float('inf')\n",
    "best_state_dict = None\n",
    "best_epochs = 0\n",
    "total_steps = num_epochs * len(train_loader)\n",
    "progress_bar = tqdm(total=total_steps)\n",
    "last_epoch_loss = 0\n",
    "\n",
    "with torch.enable_grad():\n",
    "    for epoch in range(num_epochs):\n",
    "        epoch_loss = 0\n",
    "        model.train()\n",
    "        for inputs, labels in train_loader:\n",
    "\n",
    "            inputs = inputs.to(device)\n",
    "            labels = labels.to(device)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(inputs)\n",
    "            loss = criterion(outputs, labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            epoch_loss += loss.item()\n",
    "            progress_bar.update(1)\n",
    "\n",
    "            progress_bar.set_description(\n",
    "                f\"Epoch {epoch+1}, Loss: {loss.item():.4f}\" \n",
    "                + f\", LastEpochLoss: {last_epoch_loss:.4f}\"\n",
    "            )\n",
    "            \n",
    "        last_epoch_loss = epoch_loss / len(train_loader)\n",
    "        train_loss.append(last_epoch_loss)\n",
    "\n",
    "        if last_epoch_loss < best_loss:\n",
    "            best_loss = last_epoch_loss\n",
    "            best_state_dict = model.state_dict()\n",
    "            best_epochs = epoch\n",
    "\n",
    "progress_bar.close()\n",
    "print(f\"Best Epochs: {best_epochs}; Best Loss: {best_loss:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if best_state_dict is not None:\n",
    "    model.load_state_dict(best_state_dict)\n",
    "\n",
    "data_manager.save_model(model, mode='output')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 测试模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = DeepResNet(\n",
    "    input_dim=10,\n",
    "    output_dim=10,\n",
    "    hidden_dims=[512, 1024, 2048, 1024, 512]\n",
    ")\n",
    "\n",
    "model = data_manager.load_model(model, mode='output')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "\n",
    "test_loss = 0\n",
    "\n",
    "all_preds = []\n",
    "all_labels = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for inputs, labels in test_loader:\n",
    "        outputs = model(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        test_loss += loss.item()\n",
    "        all_preds.append(outputs.cpu().numpy())\n",
    "        all_labels.append(labels.cpu().numpy())\n",
    "\n",
    "test_loss /= len(test_loader)\n",
    "print(f\"Test Loss: {test_loss:.6f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 将预测和标签转换为numpy数组以便绘图\n",
    "all_preds = np.concatenate(all_preds, axis=0)  # [n_samples, n_ports]\n",
    "all_labels = np.concatenate(all_labels, axis=0)  # [n_samples, n_ports]\n",
    "\n",
    "# 创建图形\n",
    "plt.figure(figsize=(15, 5))\n",
    "\n",
    "# 1. 样本预测对比图\n",
    "plt.subplot(131)\n",
    "sample_idx = 0  # 可以改变这个索引来查看不同的样本\n",
    "plt.plot(all_preds[sample_idx], 'b-o', label='预测值')\n",
    "plt.plot(all_labels[sample_idx], 'r-o', label='真实值')\n",
    "plt.title(f'样本 {sample_idx} 的预测结果')\n",
    "plt.xlabel('端口')\n",
    "plt.ylabel('强度')\n",
    "plt.legend()\n",
    "plt.grid(True)\n",
    "\n",
    "# 2. 预测值与真实值的散点图\n",
    "plt.subplot(132)\n",
    "plt.scatter(all_labels.flatten(), all_preds.flatten(), alpha=0.5)\n",
    "plt.plot([0, 1], [0, 1], 'r--')  # 理想的对角线\n",
    "plt.title('预测值 vs 真实值')\n",
    "plt.xlabel('真实值')\n",
    "plt.ylabel('预测值')\n",
    "plt.grid(True)\n",
    "\n",
    "# 3. 每个端口的误差箱型图\n",
    "plt.subplot(133)\n",
    "errors = all_preds - all_labels\n",
    "plt.boxplot(errors)\n",
    "plt.title('各端口预测误差分布')\n",
    "plt.xlabel('端口')\n",
    "plt.ylabel('误差')\n",
    "plt.grid(True)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# 计算统计指标\n",
    "mse = np.mean((all_preds - all_labels) ** 2)\n",
    "mae = np.mean(np.abs(all_preds - all_labels))\n",
    "r2 = 1 - np.sum((all_preds - all_labels) ** 2) / np.sum((all_labels - np.mean(all_labels)) ** 2)\n",
    "\n",
    "print(\"整体性能指标:\")\n",
    "print(f\"均方误差 (MSE): {mse:.6f}\")\n",
    "print(f\"平均绝对误差 (MAE): {mae:.6f}\")\n",
    "print(f\"R² 分数: {r2:.6f}\")\n",
    "\n",
    "# 计算每个端口的性能指标\n",
    "print(\"\\n各端口性能指标:\")\n",
    "for port in range(all_preds.shape[1]):\n",
    "    port_mse = np.mean((all_preds[:, port] - all_labels[:, port]) ** 2)\n",
    "    port_mae = np.mean(np.abs(all_preds[:, port] - all_labels[:, port]))\n",
    "    port_r2 = 1 - np.sum((all_preds[:, port] - all_labels[:, port]) ** 2) / \\\n",
    "              np.sum((all_labels[:, port] - np.mean(all_labels[:, port])) ** 2)\n",
    "    print(f\"\\n端口 {port}:\")\n",
    "    print(f\"  MSE: {port_mse:.6f}\")\n",
    "    print(f\"  MAE: {port_mae:.6f}\")\n",
    "    print(f\"  R²: {port_r2:.6f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "OpTorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
